﻿using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Text;
using System.Text.RegularExpressions;
using System.Xml.Linq;

namespace GroupedNativeMethodsGenerator;

[Generator(LanguageNames.CSharp)]
public partial class GroupedNativeMethodsGenerator : IIncrementalGenerator
{
    /// <summary>
    /// Sets up the generator pipeline: registers the source of <c>GroupedNativeMethodsAttribute</c> as
    /// post-initialization output and registers <see cref="Emit"/> as the source output for every class
    /// declaration annotated with that attribute.
    /// </summary>
    /// <param name="context">The initialization context supplied by the compiler.</param>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // The marker attribute is shipped as generated source so consuming code can reference it.
        context.RegisterPostInitializationOutput(postInit =>
            postInit.AddSource("GroupedNativeMethodsGenerator.Attribute.cs", """
using System;

namespace GroupedNativeMethodsGenerator
{
    [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)]
    internal sealed class GroupedNativeMethodsAttribute : Attribute
    {
        public string RemovePrefix { get; }
        public string RemoveSuffix { get; }
        public bool RemoveUntilTypeName { get; }
        public bool FixMethodName { get; }

        public GroupedNativeMethodsAttribute(string removePrefix = "", string removeSuffix = "", bool removeUntilTypeName = true, bool fixMethodName = true)
        {
            this.RemovePrefix = removePrefix;
            this.RemoveSuffix = removeSuffix;
            this.RemoveUntilTypeName = removeUntilTypeName;
            this.FixMethodName = fixMethodName;
        }
    }
}
"""));

        // Only class declarations pass the syntax filter; the attribute context is forwarded unchanged.
        var annotatedClasses = context.SyntaxProvider.ForAttributeWithMetadataName(
            "GroupedNativeMethodsGenerator.GroupedNativeMethodsAttribute",
            (syntaxNode, cancellationToken) => syntaxNode is ClassDeclarationSyntax,
            (attributeContext, cancellationToken) => attributeContext);

        context.RegisterSourceOutput(annotatedClasses, Emit);
    }

    /// <summary>
    /// Generates the <c>{TypeName}GroupingExtensions</c> static class for one annotated type.
    /// Every ordinary static method whose first parameter is a pointer to a non-special, non-pointer type
    /// gets an extension method on <c>ref</c> of the pointed-at type that forwards to the original method.
    /// </summary>
    /// <param name="context">Receives the generated source file.</param>
    /// <param name="source">The annotated type together with its attribute data.</param>
    static void Emit(SourceProductionContext context, GeneratorAttributeSyntaxContext source)
    {
        var typeSymbol = (INamedTypeSymbol)source.TargetSymbol;

        var ns = typeSymbol.ContainingNamespace.IsGlobalNamespace
            ? ""
            : $"namespace {typeSymbol.ContainingNamespace}\n{{";

        var accessibility = typeSymbol.DeclaredAccessibility == Accessibility.Public ? "public" : "internal";

        // Used as the hint name of the generated file, so generic brackets are replaced.
        var fullType = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
            .Replace("global::", "")
            .Replace("<", "_")
            .Replace(">", "_");

        var grouped = typeSymbol.GetMembers().OfType<IMethodSymbol>()
            // The wrapper calls `Type.Method(...)`, which is only valid for ordinary static methods
            // (constructors, accessors, operators and instance methods cannot be forwarded that way).
            .Where(x => x.MethodKind == MethodKind.Ordinary && x.IsStatic)
            .Where(x => x.Parameters.Length != 0)
            .Where(x => x.Parameters[0].Type is IPointerTypeSymbol t && (t.PointedAtType.SpecialType is SpecialType.None) && t.PointedAtType.TypeKind != TypeKind.Pointer)
            .ToLookup(x => ((IPointerTypeSymbol)x.Parameters[0].Type).PointedAtType.ToDisplayString());

        var libTypeName = typeSymbol.Name;
        var attribute = source.Attributes[0];
        var removePrefix = (string)attribute.ConstructorArguments[0].Value!;
        var removeSuffix = (string)attribute.ConstructorArguments[1].Value!;
        var removeUntilTypeName = (bool)attribute.ConstructorArguments[2].Value!;
        var fixMethodName = (bool)attribute.ConstructorArguments[3].Value!;

        var code = new StringBuilder();

        code.AppendLine($$"""
// <auto-generated/>
#nullable enable
#pragma warning disable CS8600
#pragma warning disable CS8601
#pragma warning disable CS8602
#pragma warning disable CS8603
#pragma warning disable CS8604

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

{{ns}}

    {{accessibility}} static unsafe class {{typeSymbol.Name}}GroupingExtensions
    {
""");
        foreach (var g in grouped)
        {
            code.AppendLine($"#region {g.Key}({g.Count()})");
            code.AppendLine();
            foreach (var item in g)
            {
                var firstArgument = item.Parameters[0];
                var pointedAtType = ((IPointerTypeSymbol)firstArgument.Type).PointedAtType;
                var ret = item.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
                var requireRet = item.ReturnsVoid ? "" : "return ";

                var summaryComment = GetSummaryComment(item);

                var convertedMethodName = ConvertMethodName(pointedAtType.Name, item.Name, removePrefix, removeSuffix, removeUntilTypeName, fixMethodName);
                var pointedType = pointedAtType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

                // ref/out/in must appear both in the wrapper signature and at the forwarding call site.
                var parameterPairs = string.Join("", item.Parameters.Skip(1).Select(x => $", {GetRefKindKeyword(x.RefKind)}{x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} @{x.Name}"));
                var parameterNames = string.Join("", item.Parameters.Skip(1).Select(x => $", {GetRefKindKeyword(x.RefKind)}@{x.Name}"));

                // FirstOrDefault: a name-only match must never throw, even if several attributes share the name.
                var obsoleteAttribute = item.GetAttributes().FirstOrDefault(x => x.AttributeClass?.Name == nameof(ObsoleteAttribute));

                if (summaryComment != null)
                {
                    code.AppendLine("        " + summaryComment);
                }
                if (obsoleteAttribute != null)
                {
                    code.AppendLine("        " + ObsoleteAttributeToString(obsoleteAttribute));
                }
                code.AppendLine("        [MethodImpl(MethodImplOptions.AggressiveInlining)]");
                code.AppendLine($"        public static {ret} {convertedMethodName}(this ref {pointedType} @{firstArgument.Name}{parameterPairs})");
                code.AppendLine("        {");
                code.AppendLine($"            {requireRet}{libTypeName}.{item.Name}(({pointedType}*)Unsafe.AsPointer(ref @{firstArgument.Name}){parameterNames});");
                code.AppendLine("        }");
                code.AppendLine("");
            }
            code.AppendLine("#endregion");
            code.AppendLine();
        }

        code.AppendLine("    }");
        if (ns != "")
        {
            // Close the namespace block opened in the header.
            code.AppendLine("}");
        }

        context.AddSource($"{fullType}.GroupedNativeMethods.g.cs", code.ToString());
    }

    /// <summary>
    /// Builds a single-line <c>/// &lt;summary&gt;...&lt;/summary&gt;</c> comment from the documentation of
    /// <paramref name="method"/>.
    /// </summary>
    /// <returns>
    /// The comment text, or <c>null</c> when the method has no documentation, the documentation XML cannot
    /// be parsed, or it contains no <c>summary</c> element.
    /// </returns>
    static string? GetSummaryComment(IMethodSymbol method)
    {
        var docComment = method.GetDocumentationCommentXml();
        if (string.IsNullOrEmpty(docComment))
        {
            return null;
        }

        XElement? summary;
        try
        {
            summary = XElement.Parse(docComment).Element("summary");
        }
        catch (System.Xml.XmlException)
        {
            // Malformed documentation must not take the whole generator down; just omit the comment.
            return null;
        }

        if (summary is null)
        {
            return null;
        }

        // Line breaks are flattened so the element fits on a single `///` line.
        return "/// " + summary.ToString().Replace("\r\n", " ").Replace("\r", " ").Replace("\n", " ");
    }

    /// <summary>
    /// Maps a parameter <see cref="RefKind"/> to the C# modifier keyword (with trailing space) that must
    /// precede the parameter; returns an empty string for by-value parameters and any other kind.
    /// </summary>
    static string GetRefKindKeyword(RefKind refKind)
    {
        return refKind switch
        {
            RefKind.Ref => "ref ",
            RefKind.Out => "out ",
            RefKind.In => "in ",
            _ => ""
        };
    }

    /// <summary>
    /// Derives the name of the generated extension method from the native method name.
    /// </summary>
    /// <param name="typeName">Name of the type the first pointer parameter points at.</param>
    /// <param name="methodName">Original method name.</param>
    /// <param name="removePrefix">Prefix to strip; only applied when the type-name trim did not apply.</param>
    /// <param name="removeSuffix">Suffix to strip; always applied when present.</param>
    /// <param name="removeUntilTypeName">
    /// When <c>true</c>, everything up to and including the type name (as written, or in snake_case) is removed first.
    /// </param>
    /// <param name="fixMethodName">When <c>false</c>, <paramref name="methodName"/> is returned untouched.</param>
    /// <returns>
    /// The converted name with underscore-separated words capitalized and joined. If the conversion would
    /// leave nothing, the original <paramref name="methodName"/> is returned so the generated code stays valid.
    /// </returns>
    static string ConvertMethodName(string typeName, string methodName, string removePrefix, string removeSuffix, bool removeUntilTypeName, bool fixMethodName)
    {
        if (!fixMethodName)
        {
            return methodName;
        }

        var converted = methodName;
        var trimmedByTypeName = false;

        // Try the type name as written first, then its snake_case form.
        if (removeUntilTypeName
            && (TryTrimPrefix(converted, typeName, out var trimmed)
                || TryTrimPrefix(converted, ToSnakeCase(typeName), out trimmed)))
        {
            converted = trimmed;
            trimmedByTypeName = true;
        }

        // The explicit prefix is a fallback: it is skipped once the type-name trim succeeded.
        if (!trimmedByTypeName
            && !string.IsNullOrEmpty(removePrefix)
            && converted.StartsWith(removePrefix, StringComparison.Ordinal))
        {
            converted = converted.Substring(removePrefix.Length);
        }

        if (!string.IsNullOrEmpty(removeSuffix)
            && converted.EndsWith(removeSuffix, StringComparison.Ordinal))
        {
            converted = converted.Substring(0, converted.Length - removeSuffix.Length);
        }

        converted = ToCamelCase(converted.Trim('_', ' '));

        // An empty identifier would produce uncompilable output; keep the original name instead.
        return converted.Length == 0 ? methodName : converted;
    }

    /// <summary>
    /// Renders an obsolete-attribute application as C# source text: <c>[Obsolete]</c> when it carries no
    /// arguments, otherwise <c>[Obsolete(...)]</c> with the constructor arguments followed by the
    /// <c>Name = value</c> named arguments, separated by <c>", "</c>.
    /// </summary>
    /// <param name="obsoleteAttribute">The attribute data to render.</param>
    /// <returns>The attribute as source text.</returns>
    static string ObsoleteAttributeToString(AttributeData obsoleteAttribute)
    {
        var positional = obsoleteAttribute.ConstructorArguments;
        var named = obsoleteAttribute.NamedArguments;

        if (positional.IsEmpty && named.IsEmpty)
        {
            return "[Obsolete]";
        }

        var text = new StringBuilder("[Obsolete(");
        var separator = "";

        foreach (var argument in positional)
        {
            text.Append(separator).Append(argument.ToCSharpString());
            separator = ", ";
        }

        foreach (var pair in named)
        {
            text.Append(separator).Append(pair.Key).Append(" = ").Append(pair.Value.ToCSharpString());
            separator = ", ";
        }

        return text.Append(")]").ToString();
    }

    /// <summary>
    /// Searches <paramref name="value"/> for the first ordinal occurrence of <paramref name="prefix"/>
    /// (anywhere in the string, not only at the start) and yields the text that follows it, with leading
    /// and trailing spaces and underscores removed.
    /// </summary>
    /// <param name="value">The text to search.</param>
    /// <param name="prefix">The marker whose end determines where the remainder starts.</param>
    /// <param name="result">
    /// The trimmed remainder when the marker was found (possibly empty); <c>null</c> when it was not found.
    /// </param>
    /// <returns><c>true</c> only when the marker was found and the remainder is non-empty.</returns>
    static bool TryTrimPrefix(string value, string prefix, out string result)
    {
        var index = value.IndexOf(prefix, StringComparison.Ordinal);
        if (index < 0)
        {
            result = default!;
            return false;
        }

        var remainderStart = index + prefix.Length;
        result = value.Substring(remainderStart).Trim(' ', '_');
        return result.Length != 0;
    }

    /// <summary>
    /// Converts an underscore-separated name into a concatenation of capitalized words
    /// (for example <c>foo_bar</c> becomes <c>FooBar</c>). Empty segments are dropped.
    /// </summary>
    /// <param name="snakeCase">The underscore-separated name.</param>
    /// <returns>The words joined together, each with its first character upper-cased.</returns>
    static string ToCamelCase(string snakeCase)
    {
        var builder = new StringBuilder(snakeCase.Length);

        foreach (var word in snakeCase.Split('_'))
        {
            if (word.Length == 0)
            {
                continue;
            }

            // Invariant casing: generated identifiers must not depend on the build machine's culture
            // (culture-sensitive upper-casing turns 'i' into 'İ' under Turkish cultures).
            builder.Append(char.ToUpperInvariant(word[0]));
            builder.Append(word, 1, word.Length - 1);
        }

        return builder.ToString();
    }

    /// <summary>
    /// Converts a PascalCase/camelCase name to snake_case: the first character is lower-cased, and every
    /// later upper-case character is replaced by an underscore followed by its lower-case form
    /// (for example <c>FooBar</c> becomes <c>foo_bar</c>).
    /// </summary>
    /// <param name="camelCase">The name to convert.</param>
    /// <returns>The snake_case name; an empty string when the input is null or empty.</returns>
    static string ToSnakeCase(string camelCase)
    {
        // Guard: indexing the first character of an empty name would throw.
        if (string.IsNullOrEmpty(camelCase))
        {
            return string.Empty;
        }

        // Worst case every character gains an underscore, hence twice the input length.
        var builder = new StringBuilder(camelCase.Length * 2);
        builder.Append(char.ToLowerInvariant(camelCase[0]));

        for (var i = 1; i < camelCase.Length; ++i)
        {
            var ch = camelCase[i];
            if (char.IsUpper(ch))
            {
                builder.Append('_').Append(char.ToLowerInvariant(ch));
            }
            else
            {
                builder.Append(ch);
            }
        }

        return builder.ToString();
    }
}